import abc
import pickle
from typing import Any

import torch


class SerializableCallable(abc.ABC):
    """
    Abstract interface for a callable whose compile artifacts can be turned
    into ``bytes`` and back.

    Concrete subclasses must provide the classmethod pair
    ``serialize_compile_artifacts`` / ``deserialize_compile_artifacts`` and
    make their instances callable through ``__call__``.
    """

    @classmethod
    @abc.abstractmethod
    def serialize_compile_artifacts(cls, fn: Any) -> bytes:
        """Encode the compile artifacts held by ``fn`` as ``bytes``."""
        ...

    @classmethod
    @abc.abstractmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """
        Reconstruct an object from ``data``; intended as the counterpart of
        ``serialize_compile_artifacts``.
        """
        ...

    @abc.abstractmethod
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the underlying callable with the given arguments."""
        ...


class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    Attribute lookups that are not found on the wrapper itself are forwarded
    to the wrapped ``compiled_fn``.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, compiled_fn: Any) -> None:
        """
        Wrap ``compiled_fn``, the compiled function produced by AOTAutograd.

        ``compiled_fn`` is invoked by ``__call__`` and must expose a
        ``serialize()`` method, whose result is what
        ``serialize_compile_artifacts`` pickles.
        """
        assert hasattr(compiled_fn, "serialize")
        self.compiled_fn = compiled_fn

    def __getattr__(self, attr: str) -> Any:
        """Forward attribute lookups that miss on the wrapper to ``compiled_fn``."""
        # __getattr__ only runs after normal lookup fails. If ``compiled_fn``
        # is itself absent (instance created without __init__, e.g. by
        # copy.copy, pickle reconstruction or cls.__new__), reading
        # ``self.compiled_fn`` here would re-enter __getattr__ forever and
        # raise RecursionError. Read the instance dict directly and report a
        # plain AttributeError so hasattr()/getattr(..., default) work.
        try:
            compiled_fn = self.__dict__["compiled_fn"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(compiled_fn, attr)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        """
        Pickle the result of ``fn.compiled_fn.serialize()``.

        The call runs with torch._functorch.config ``bundled_autograd_cache``
        patched to True.
        """
        with torch._functorch.config.patch("bundled_autograd_cache", True):
            return pickle.dumps(fn.compiled_fn.serialize())

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """
        Rebuild a wrapper from bytes produced by ``serialize_compile_artifacts``.

        The unpickled entry is handed to ``deserialize_bundled_cache_entry`` and
        the resulting compiled function is wrapped in a new instance of ``cls``.
        """
        from torch._functorch._aot_autograd.aot_autograd_result import (
            deserialize_bundled_cache_entry,
        )

        # SECURITY NOTE(review): pickle.loads can execute arbitrary code; this
        # must only be fed artifacts from a trusted source (e.g. a local cache
        # this process or a trusted peer wrote).
        entry = pickle.loads(data)

        compiled_fn = deserialize_bundled_cache_entry(entry)
        return cls(compiled_fn)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped compiled function with the given arguments."""
        return self.compiled_fn(*args, **kwargs)
